#!/usr/bin/env python
# -*- coding: UTF-8 -*-

'''
标签数据加载器
'''

from Loader import *


class LabelLoader(Loader):

    def load(self):
        '''
        加载数据文件，获取全部样本的标签向量
        :return:
        '''
        content = self.get_file_content()
        labels = []
        for index in range(self.count):
            labels.append(self.norm(content[index + 8]))
        return labels

    def norm(self, label):
        '''
        内部函数，将一个值转换为10维标签向量
        :param label:
        :return:
        '''
        label_vec = []
        label_value = self.to_int(label)
        for i in range(10):
            if i == label_value:
                label_vec.append(0.0)
            else:
                label_vec.append(0.1)
        return label_vec




